{
  "cells": [
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "-PKa6W4wdPWr"
      },
      "outputs": [],
      "source": [
        "# Copyright 2024 Google LLC\n",
        "#\n",
        "# Licensed under the Apache License, Version 2.0 (the \"License\");\n",
        "# you may not use this file except in compliance with the License.\n",
        "# You may obtain a copy of the License at\n",
        "#\n",
        "#     https://www.apache.org/licenses/LICENSE-2.0\n",
        "#\n",
        "# Unless required by applicable law or agreed to in writing, software\n",
        "# distributed under the License is distributed on an \"AS IS\" BASIS,\n",
        "# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n",
        "# See the License for the specific language governing permissions and\n",
        "# limitations under the License."
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "idM_aIPheQDG"
      },
      "source": [
        "# Stage 2: Building MVP: - 03 RAG\n",
        "\n",
        "\n",
        "<table align=\"left\">\n",
        "  <td style=\"text-align: center\">\n",
        "    <a href=\"https://colab.research.google.com/github/GoogleCloudPlatform/generative-ai/blob/main/workshops/rag-ops/2.3_mvp_rag.ipynb\">\n",
        "      <img width=\"32px\" src=\"https://www.gstatic.com/pantheon/images/bigquery/welcome_page/colab-logo.svg\" alt=\"Google Colaboratory logo\"><br> Open in Colab\n",
        "    </a>\n",
        "  </td>\n",
        "  <td style=\"text-align: center\">\n",
        "    <a href=\"https://console.cloud.google.com/vertex-ai/colab/import/https:%2F%2Fraw.githubusercontent.com%2FGoogleCloudPlatform%2Fgenerative-ai%2Fmain%2Fworkshops%2Frag-ops%2F2.3_mvp_rag.ipynb\">\n",
        "      <img width=\"32px\" src=\"https://lh3.googleusercontent.com/JmcxdQi-qOpctIvWKgPtrzZdJJK-J3sWE1RsfjZNwshCFgE_9fULcNpuXYTilIR2hjwN\" alt=\"Google Cloud Colab Enterprise logo\"><br> Open in Colab Enterprise\n",
        "    </a>\n",
        "  </td>    \n",
        "  <td style=\"text-align: center\">\n",
        "    <a href=\"https://console.cloud.google.com/vertex-ai/workbench/deploy-notebook?download_url=https://raw.githubusercontent.com/GoogleCloudPlatform/generative-ai/main/workshops/rag-ops/2.3_mvp_rag.ipynb\">\n",
        "      <img src=\"https://lh3.googleusercontent.com/UiNooY4LUgW_oTvpsNhPpQzsstV5W8F7rYgxgGBD85cWJoLmrOzhVs_ksK_vgx40SHs7jCqkTkCk=e14-rj-sc0xffffff-h130-w32\" alt=\"Vertex AI logo\"><br> Open in Workbench\n",
        "    </a>\n",
        "  </td>\n",
        "  <td style=\"text-align: center\">\n",
        "    <a href=\"https://github.com/GoogleCloudPlatform/generative-ai/blob/main/workshops/rag-ops/2.3_mvp_rag.ipynb\">\n",
        "      <img width=\"32px\" src=\"https://raw.githubusercontent.com/primer/octicons/refs/heads/main/icons/mark-github-24.svg\" alt=\"GitHub logo\"><br> View on GitHub\n",
        "    </a>\n",
        "  </td>\n",
        "</table>\n",
        "\n",
        "<div style=\"clear: both;\"></div>\n",
        "\n",
        "<b>Share to:</b>\n",
        "\n",
        "<a href=\"https://www.linkedin.com/sharing/share-offsite/?url=https%3A//github.com/GoogleCloudPlatform/generative-ai/blob/main/workshops/rag-ops/2.3_mvp_rag.ipynb\" target=\"_blank\">\n",
        "  <img width=\"20px\" src=\"https://upload.wikimedia.org/wikipedia/commons/8/81/LinkedIn_icon.svg\" alt=\"LinkedIn logo\">\n",
        "</a>\n",
        "\n",
        "<a href=\"https://bsky.app/intent/compose?text=https%3A//github.com/GoogleCloudPlatform/generative-ai/blob/main/workshops/rag-ops/2.3_mvp_rag.ipynb\" target=\"_blank\">\n",
        "  <img width=\"20px\" src=\"https://upload.wikimedia.org/wikipedia/commons/7/7a/Bluesky_Logo.svg\" alt=\"Bluesky logo\">\n",
        "</a>\n",
        "\n",
        "<a href=\"https://twitter.com/intent/tweet?url=https%3A//github.com/GoogleCloudPlatform/generative-ai/blob/main/workshops/rag-ops/2.3_mvp_rag.ipynb\" target=\"_blank\">\n",
        "  <img width=\"20px\" src=\"https://upload.wikimedia.org/wikipedia/commons/5/5a/X_icon_2.svg\" alt=\"X logo\">\n",
        "</a>\n",
        "\n",
        "<a href=\"https://reddit.com/submit?url=https%3A//github.com/GoogleCloudPlatform/generative-ai/blob/main/workshops/rag-ops/2.3_mvp_rag.ipynb\" target=\"_blank\">\n",
        "  <img width=\"20px\" src=\"https://redditinc.com/hubfs/Reddit%20Inc/Brand/Reddit_Logo.png\" alt=\"Reddit logo\">\n",
        "</a>\n",
        "\n",
        "<a href=\"https://www.facebook.com/sharer/sharer.php?u=https%3A//github.com/GoogleCloudPlatform/generative-ai/blob/main/workshops/rag-ops/2.3_mvp_rag.ipynb\" target=\"_blank\">\n",
        "  <img width=\"20px\" src=\"https://upload.wikimedia.org/wikipedia/commons/5/51/Facebook_f_logo_%282019%29.svg\" alt=\"Facebook logo\">\n",
        "</a>            "
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "o0oOZgLKfhCi"
      },
      "source": [
        "## Overview\n",
        "\n",
        "This notebook is the third in a series designed to guide you through building a Minimum Viable Product (MVP) for a Multimodal Retrieval Augmented Generation (RAG) system using the Vertex Gemini API.\n",
        "\n",
        "Building on the foundation of text chunking and embedding generation from the previous notebooks, we now bring those pieces together to construct the core retrieval and generation components of our RAG system. This notebook demonstrates a practical implementation of these components using a ground truth dataset, showcasing how to effectively retrieve relevant information and generate accurate answers.\n",
        "\n",
        "**Here's what you'll achieve:**\n",
        "\n",
        "* **Build a Retrieval System:**  Develop a robust retrieval system that leverages semantic search and Vertex AI Embeddings to efficiently identify the most pertinent text chunks for any given query. This system will utilize the vector database created in the previous notebook to perform accurate similarity searches.\n",
        "* **Implement Generation:**  Construct a generation component that utilizes the retrieved information to produce comprehensive and informative answers. This process involves selecting the top N results from the retrieved chunks and feeding them to Gemini 2.0 and Flash models for generating the final answer.\n",
        "* **Generate Citations:**  Incorporate functionality to generate citations for the answers produced by your RAG system. This ensures transparency and traceability of the information provided, enhancing the reliability and trustworthiness of your MVP.\n",
        "* **Test with Ground Truth Data:**  Apply your complete RAG system to a ground truth dataset, observing its performance on real-world questions and answers. This exercise provides valuable insights into the effectiveness of your retrieval and generation pipeline.\n",
        "\n",
        "This notebook bridges the gap between individual components and a functional RAG system. By implementing retrieval and generation mechanisms and testing them on ground truth data, you gain a practical understanding of how to build an end-to-end solution for answering questions based on multimodal sources.\n"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "6KGP8kNhfklW"
      },
      "source": [
        "## Getting Started"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "HGFDxQ7_flui"
      },
      "source": [
        "### Install Vertex AI SDK for Python\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 1,
      "metadata": {
        "id": "YcWmeELUeTPq"
      },
      "outputs": [],
      "source": [
        "%pip install --upgrade --user --quiet google-cloud-aiplatform"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "SYYzbNlmfvKJ"
      },
      "source": [
        "### Restart runtime\n",
        "\n",
        "To use the newly installed packages in this Jupyter runtime, you must restart the runtime. You can do this by running the cell below, which restarts the current kernel."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 2,
      "metadata": {
        "id": "icVKaoLAfw6c"
      },
      "outputs": [],
      "source": [
        "import sys\n",
        "\n",
        "if \"google.colab\" in sys.modules:\n",
        "    import IPython\n",
        "\n",
        "    app = IPython.Application.instance()\n",
        "    app.kernel.do_shutdown(True)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "oZoekSN8fy2E"
      },
      "source": [
        "<div class=\"alert alert-block alert-warning\">\n",
        "<b>⚠️ The kernel is going to restart. Please wait until it is finished before continuing to the next step. ⚠️</b>\n",
        "</div>\n"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "E4rlvfF1f1RA"
      },
      "source": [
        "### Authenticate your notebook environment (Colab only)\n",
        "\n",
        "If you are running this notebook on Google Colab, run the cell below to authenticate your environment.\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 1,
      "metadata": {
        "id": "XTaUAqKLf3N8"
      },
      "outputs": [],
      "source": [
        "import sys\n",
        "\n",
        "if \"google.colab\" in sys.modules:\n",
        "    from google.colab import auth\n",
        "\n",
        "    auth.authenticate_user()"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "iQJQgZKXf6Mx"
      },
      "source": [
        "### Set Google Cloud project information, GCS Bucket and initialize Vertex AI SDK\n",
        "\n",
        "To get started using Vertex AI, you must have an existing Google Cloud project and [enable the Vertex AI API](https://console.cloud.google.com/flows/enableapi?apiid=aiplatform.googleapis.com).\n",
        "\n",
        "Learn more about [setting up a project and a development environment](https://cloud.google.com/vertex-ai/docs/start/cloud-environment)."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 2,
      "metadata": {
        "id": "FqXU2Ptvf5LA"
      },
      "outputs": [],
      "source": [
        "import os\n",
        "import sys\n",
        "\n",
        "from google.cloud import storage\n",
        "import vertexai\n",
        "\n",
        "PROJECT_ID = \"[your-project-id]\"  # @param {type:\"string\"}\n",
        "LOCATION = \"us-central1\"\n",
        "BUCKET_NAME = \"mlops-for-genai\"\n",
        "\n",
        "if PROJECT_ID == \"[your-project-id]\":\n",
        "    PROJECT_ID = str(os.environ.get(\"GOOGLE_CLOUD_PROJECT\"))\n",
        "\n",
        "if not PROJECT_ID or PROJECT_ID == \"[your-project-id]\" or PROJECT_ID == \"None\":\n",
        "    raise ValueError(\"Please set your PROJECT_ID\")\n",
        "\n",
        "\n",
        "vertexai.init(project=PROJECT_ID, location=LOCATION)\n",
        "\n",
        "# Initialize cloud storage\n",
        "storage_client = storage.Client(project=PROJECT_ID)\n",
        "bucket = storage_client.bucket(BUCKET_NAME)"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 3,
      "metadata": {
        "id": "6AbGiKORhxQ7"
      },
      "outputs": [],
      "source": [
        "# # Variables for data location. Do not change.\n",
        "\n",
        "PRODUCTION_DATA = \"multimodal-finanace-qa/data/unstructured/production/\"\n",
        "PICKLE_FILE_NAME = \"index_db.pkl\""
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "s9-h_WOcgAKX"
      },
      "source": [
        "### Import libraries\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 4,
      "metadata": {
        "id": "OaW7NsbHgAoo"
      },
      "outputs": [],
      "source": [
        "# Library\n",
        "\n",
        "import pickle\n",
        "\n",
        "from google.cloud import storage\n",
        "import numpy as np\n",
        "import pandas as pd\n",
        "from rich import print as rich_print\n",
        "from vertexai.generative_models import GenerativeModel, HarmBlockThreshold, HarmCategory\n",
        "from vertexai.language_models import TextEmbeddingInput, TextEmbeddingModel"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "A9hij3nhgDz8"
      },
      "source": [
        "### Load the Gemini 2.0 models\n",
        "\n",
        "To learn more about all [Gemini API models on Vertex AI](https://cloud.google.com/vertex-ai/generative-ai/docs/learn/models#gemini-models).\n",
        "\n",
        "The Gemini model family has several model versions. You will start by using Gemini 2.0. Gemini 2.0 is a more lightweight, fast, and cost-efficient model. This makes it a great option for prototyping.\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 5,
      "metadata": {
        "id": "7xvFkagTgHMc"
      },
      "outputs": [],
      "source": [
        "MODEL_ID_FLASH = \"gemini-2.0-flash\"  # @param {type:\"string\"}\n",
        "MODEL_ID_PRO = \"gemini-2.0-flash\"  # @param {type:\"string\"}\n",
        "\n",
        "\n",
        "gemini_15_flash = GenerativeModel(MODEL_ID_FLASH)\n",
        "gemini_15_pro = GenerativeModel(MODEL_ID_PRO)"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 6,
      "metadata": {
        "cellView": "form",
        "id": "oS1ar31xYmkR"
      },
      "outputs": [],
      "source": [
        "# @title Helper Functions\n",
        "\n",
        "\n",
        "def get_load_dataframes_from_gcs():\n",
        "    gcs_path = \"multimodal-finanace-qa/data/embeddings/\" + PICKLE_FILE_NAME\n",
        "    # print(\"GCS PAth: \", gcs_path)\n",
        "    blob = bucket.blob(gcs_path)\n",
        "\n",
        "    # Download the pickle file from GCS\n",
        "    blob.download_to_filename(f\"{PICKLE_FILE_NAME}\")\n",
        "\n",
        "    # Load the pickle file into a list of dataframes\n",
        "    with open(f\"{PICKLE_FILE_NAME}\", \"rb\") as f:\n",
        "        dataframes = pickle.load(f)\n",
        "\n",
        "    # Assign the dataframes to variables\n",
        "    (\n",
        "        index_db_final,\n",
        "        extracted_text_chunk_df,\n",
        "        video_metadata_chunk_df,\n",
        "        audio_metadata_chunk_df,\n",
        "    ) = dataframes\n",
        "\n",
        "    return (\n",
        "        index_db_final,\n",
        "        extracted_text_chunk_df,\n",
        "        video_metadata_chunk_df,\n",
        "        audio_metadata_chunk_df,\n",
        "    )"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "KKT4gg2ybwaR"
      },
      "source": [
        "![](https://storage.googleapis.com/mlops-for-genai/multimodal-finanace-qa/img/rag_eval_flow.png)"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 7,
      "metadata": {
        "id": "Jk5387FCjlAv"
      },
      "outputs": [],
      "source": [
        "# Get the data that has been extracted in the previous step: IndexDB.\n",
        "# Make sure that you have ran the previous notebook: stage_2_mvp_chunk_embeddings.ipynb\n",
        "\n",
        "\n",
        "(\n",
        "    index_db_final,\n",
        "    extracted_text_chunk_df,\n",
        "    video_metadata_chunk_df,\n",
        "    audio_metadata_chunk_df,\n",
        ") = get_load_dataframes_from_gcs()"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 8,
      "metadata": {
        "id": "isgFDZBZyZ_1"
      },
      "outputs": [],
      "source": [
        "index_db_final.head()"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 9,
      "metadata": {
        "id": "OqdsKp2UyxHu"
      },
      "outputs": [],
      "source": [
        "index_db_final.shape"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 10,
      "metadata": {
        "id": "8FITeQNFyuSw"
      },
      "outputs": [],
      "source": [
        "%%time\n",
        "# load training data\n",
        "training_data = pd.read_csv(\n",
        "    \"gs://mlops-for-genai/multimodal-finanace-qa/data/structured/training_data_subset.csv\"\n",
        ")"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 11,
      "metadata": {
        "id": "C0wICVtD2wEh"
      },
      "outputs": [],
      "source": [
        "training_data.shape"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "rJEiZPXyy2lE"
      },
      "outputs": [],
      "source": [
        "training_data.head(2)"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "5nBY_it821Rr"
      },
      "outputs": [],
      "source": [
        "training_data[\"question_type\"].value_counts()"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 12,
      "metadata": {
        "id": "q-U7b7ccy6K_"
      },
      "outputs": [],
      "source": [
        "index = 8\n",
        "\n",
        "print(\"*******The question: *******\\n\")\n",
        "rich_print(training_data[\"question\"][index])\n",
        "print(\"\\n*******The ground-truth answer:*******\")\n",
        "rich_print(training_data[\"answer\"][index])\n",
        "print(\"\\n*******The question type: *******\\n\", training_data[\"question_type\"][index])\n",
        "print(\n",
        "    \"*******The question type description: *******\\n\",\n",
        "    training_data[\"question_type_description\"][index],\n",
        ")\n",
        "print(\n",
        "    \"*******Text citation: *******\\n\",\n",
        ")\n",
        "rich_print(training_data[\"text_citation\"][index])\n",
        "print(\n",
        "    \"*******Audio citation: *******\\n\",\n",
        ")\n",
        "rich_print(training_data[\"audio_citation\"][index])\n",
        "print(\n",
        "    \"*******Video citation: *******\\n\",\n",
        ")\n",
        "rich_print(training_data[\"video_citation\"][index])"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "TIHuezEHzAHv"
      },
      "source": [
        "## Retrieval"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 13,
      "metadata": {
        "cellView": "form",
        "id": "IxkYu1SGy_4r"
      },
      "outputs": [],
      "source": [
        "# @title Retrieval Functions\n",
        "\n",
        "import numpy as np\n",
        "from vertexai.language_models import TextEmbeddingInput, TextEmbeddingModel\n",
        "\n",
        "\n",
        "def get_gemini_response(\n",
        "    model,\n",
        "    generation_config=None,\n",
        "    safety_settings=None,\n",
        "    uri_path=None,\n",
        "    mime_type=None,\n",
        "    prompt=None,\n",
        "):\n",
        "    if not generation_config:\n",
        "        generation_config = {\n",
        "            \"max_output_tokens\": 8192,\n",
        "            \"temperature\": 1,\n",
        "            \"top_p\": 0.95,\n",
        "        }\n",
        "\n",
        "    if not safety_settings:\n",
        "        safety_settings = {\n",
        "            HarmCategory.HARM_CATEGORY_HATE_SPEECH: HarmBlockThreshold.BLOCK_ONLY_HIGH,\n",
        "            HarmCategory.HARM_CATEGORY_DANGEROUS_CONTENT: HarmBlockThreshold.BLOCK_ONLY_HIGH,\n",
        "            HarmCategory.HARM_CATEGORY_SEXUALLY_EXPLICIT: HarmBlockThreshold.BLOCK_ONLY_HIGH,\n",
        "            HarmCategory.HARM_CATEGORY_HARASSMENT: HarmBlockThreshold.BLOCK_ONLY_HIGH,\n",
        "        }\n",
        "\n",
        "    responses = model.generate_content(\n",
        "        prompt,\n",
        "        generation_config=generation_config,\n",
        "        safety_settings=safety_settings,\n",
        "        stream=True,\n",
        "    )\n",
        "    final_response = []\n",
        "    for response in responses:\n",
        "        try:\n",
        "            final_response.append(response.text)\n",
        "        except ValueError:\n",
        "            # print(\"Something is blocked...\")\n",
        "            final_response.append(\"blocked\")\n",
        "\n",
        "    return \"\".join(final_response)\n",
        "\n",
        "\n",
        "def get_text_embeddings(\n",
        "    texts: list[str] = [\"banana muffins? \", \"banana bread? banana muffins?\"],\n",
        "    task: str = \"RETRIEVAL_DOCUMENT\",\n",
        "    model_name: str = \"text-embedding-005\",\n",
        ") -> list[list[float]]:\n",
        "    # print(\"doing...\")\n",
        "    \"\"\"Embeds texts with a pre-trained, foundational model.\"\"\"\n",
        "    model = TextEmbeddingModel.from_pretrained(model_name)\n",
        "    inputs = [TextEmbeddingInput(text, task) for text in texts]\n",
        "    embeddings = model.get_embeddings(inputs)\n",
        "    return [embedding.values for embedding in embeddings][0]\n",
        "\n",
        "\n",
        "def get_cosine_score(\n",
        "    dataframe: pd.DataFrame, column_name: str, input_text_embed: np.ndarray\n",
        ") -> float:\n",
        "    \"\"\"\n",
        "    Calculates the cosine similarity between the user query embedding and the dataframe embedding for a specific column.\n",
        "\n",
        "    Args:\n",
        "        dataframe: The pandas DataFrame containing the data to compare against.\n",
        "        column_name: The name of the column containing the embeddings to compare with.\n",
        "        input_text_embed: The NumPy array representing the user query embedding.\n",
        "\n",
        "    Returns:\n",
        "        The cosine similarity score (rounded to two decimal places) between the user query embedding and the dataframe embedding.\n",
        "    \"\"\"\n",
        "    if dataframe[column_name]:\n",
        "        text_cosine_score = round(np.dot(dataframe[column_name], input_text_embed), 2)\n",
        "        return text_cosine_score\n",
        "    else:\n",
        "        return 0"
      ]
    },
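    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "A note on the score computed by `get_cosine_score` above: the function takes a plain dot product. As a sketch of why that can stand in for cosine similarity, write $q$ for the query embedding and $d$ for a chunk embedding (symbols introduced here for illustration only):\n",
        "\n",
        "$$\n",
        "\\cos(q, d) = \\frac{q \\cdot d}{\\lVert q \\rVert \\, \\lVert d \\rVert}\n",
        "$$\n",
        "\n",
        "If both vectors have unit length, $\\lVert q \\rVert = \\lVert d \\rVert = 1$, so $\\cos(q, d) = q \\cdot d$ and the dot product ranks chunks exactly as cosine similarity would. This assumes the embedding model returns unit-normalized vectors, which this notebook does not verify; if they are not normalized, divide by the two norms as in the formula above."
      ]
    },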
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "IaV51_LKzB-i"
      },
      "source": [
        "## Generation"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 14,
      "metadata": {
        "cellView": "form",
        "id": "lKcBneQXzDlP"
      },
      "outputs": [],
      "source": [
        "# @title Generation Functions\n",
        "\n",
        "\n",
        "def get_answer(row, vector_db, model, top_n=5):\n",
        "    query_embedding = get_text_embeddings([row[\"question\"]])\n",
        "    # Find score\n",
        "    cosine_scores = vector_db.apply(\n",
        "        lambda x: get_cosine_score(x, \"embeddings\", query_embedding),\n",
        "        axis=1,\n",
        "    )\n",
        "    # print(len(cosine_scores))\n",
        "    # Remove same image comparison score when user image is matched exactly with metadata image\n",
        "    # cosine_scores = cosine_scores[cosine_scores < 1.00000000]\n",
        "    # Get top N cosine scores and their indices\n",
        "    top_n_cosine_scores = cosine_scores.nlargest(top_n).index.tolist()\n",
        "    top_n_cosine_values = cosine_scores.nlargest(top_n).values.tolist()\n",
        "\n",
        "    citations = vector_db.iloc[top_n_cosine_scores].copy()\n",
        "    # citations['score'] = top_n_cosine_scores\n",
        "    citations.loc[:, \"score\"] = top_n_cosine_values\n",
        "    citations = citations[[\"uid\", \"type\", \"content\", \"score\"]]\n",
        "\n",
        "    # # print(citations)\n",
        "    # gemini_content = get_gemini_content_list(query, vector_db, top_n_cosine_scores)\n",
        "    context = \"\\n\".join(citations[\"content\"].tolist())\n",
        "    prompt = f\"\"\"Task: Answer the question based on the provided context.\n",
        "\n",
        "Guidelines:\n",
        "\n",
        "1. **Chain of Thought:**  Before generating the final answer, break down the question into smaller steps and reason through them logically.\n",
        "2. **Conciseness:** If the question asks for specific information, provide only those values.\n",
        "3. **Critical Thinking:** Analyze the entire context thoroughly and critically evaluate the information before formulating your answer. Do not rush to conclusions.\n",
        "4. **Structure:** Present your answer in a clear and organized manner. It should include everything that the question is asking.\n",
        "5. **Answer Only:**  Do not include the chain of thought reasoning in the final answer. The answer should be concise.\n",
        "\n",
        "Question: {row['question']}\n",
        "\n",
        "Context: {context}\n",
        "\n",
        "Answer:\n",
        "\"\"\"\n",
        "    response = get_gemini_response(model=model, prompt=prompt)\n",
        "    type_mod = np.unique(\n",
        "        [each_cit[\"type\"] for each_cit in citations.to_dict(\"records\")]\n",
        "    )\n",
        "    return (response, context, type_mod, citations.to_dict(\"records\"), prompt)\n",
        "\n",
        "\n",
        "def get_gen_answer_citation(\n",
        "    training_data,\n",
        "    index,\n",
        "    ref_data=[\n",
        "        extracted_text_chunk_df,\n",
        "        video_metadata_chunk_df,\n",
        "        audio_metadata_chunk_df,\n",
        "    ],\n",
        "    top_n=None,\n",
        "    gcs_full_path=True,\n",
        "):\n",
        "    gen_citation = training_data[\"citation\"][index]\n",
        "    # gen_citation_uids = [each_el['uid'] for each_el in gen_citation]\n",
        "    final_citations = []\n",
        "    for each_cit in gen_citation:\n",
        "        if each_cit[\"type\"] == \"text\":\n",
        "            text_citations = ref_data[0][ref_data[0][\"uid\"].isin([each_cit[\"uid\"]])]\n",
        "            final_citations.append(\n",
        "                {\n",
        "                    \"score\": each_cit[\"score\"],\n",
        "                    \"uid\": each_cit[\"uid\"],\n",
        "                    \"type\": each_cit[\"type\"],\n",
        "                    \"content\": text_citations[\"text\"].values[0],\n",
        "                    \"gcs_path\": (\n",
        "                        text_citations[\"gcs_path\"].values[0]\n",
        "                        if gcs_full_path\n",
        "                        else text_citations[\"gcs_path\"].values[0].split(\"/\")[-1]\n",
        "                    ),\n",
        "                }\n",
        "            )\n",
        "        elif each_cit[\"type\"] == \"video\":\n",
        "            video_citations = ref_data[1][ref_data[1][\"uid\"].isin([each_cit[\"uid\"]])]\n",
        "            final_citations.append(\n",
        "                {\n",
        "                    \"score\": each_cit[\"score\"],\n",
        "                    \"uid\": each_cit[\"uid\"],\n",
        "                    \"type\": each_cit[\"type\"],\n",
        "                    \"content\": video_citations[\"video_description\"].values[0],\n",
        "                    \"gcs_path\": (\n",
        "                        video_citations[\"video_gcs\"].values[0]\n",
        "                        if gcs_full_path\n",
        "                        else video_citations[\"video_gcs\"].values[0].split(\"/\")[-1]\n",
        "                    ),\n",
        "                }\n",
        "            )\n",
        "\n",
        "        elif each_cit[\"type\"] == \"audio\":\n",
        "            audio_citations = ref_data[2][ref_data[2][\"uid\"].isin([each_cit[\"uid\"]])]\n",
        "            final_citations.append(\n",
        "                {\n",
        "                    \"score\": each_cit[\"score\"],\n",
        "                    \"uid\": each_cit[\"uid\"],\n",
        "                    \"type\": each_cit[\"type\"],\n",
        "                    \"content\": audio_citations[\"audio_description\"].values[0],\n",
        "                    \"gcs_path\": (\n",
        "                        audio_citations[\"audio_gcs\"].values[0]\n",
        "                        if gcs_full_path\n",
        "                        else audio_citations[\"audio_gcs\"].values[0].split(\"/\")[-1]\n",
        "                    ),\n",
        "                }\n",
        "            )\n",
        "    if top_n:\n",
        "        return pd.DataFrame(final_citations[:top_n])\n",
        "    else:\n",
        "        return pd.DataFrame(final_citations)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "SaHhgpQ9zeah"
      },
      "source": [
        "## Answer"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "0LEuCE_A1uOb"
      },
      "source": [
        "if you are getting: InternalServerError: 500 Internal error encountered. Try switching between 001 or 002 models. There might be a capacity issue. You can also re-run the cell after few minutes.\n",
        "\n",
        "`gemini-2.0-flash` or `gemini-2.0-flash`\n",
        "\n",
        "`gemini-2.0-flash` or `gemini-2.0-flash`"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 15,
      "metadata": {
        "id": "XApvnEByzevQ"
      },
      "outputs": [],
      "source": [
        "%%time\n",
        "training_data_flash = training_data.copy()\n",
        "training_data_flash[[\"gen_answer\", \"context\", \"type_mode\", \"citation\", \"prompt\"]] = (\n",
        "    training_data.apply(\n",
        "        lambda x: get_answer(x, index_db_final, gemini_15_flash, top_n=100),\n",
        "        axis=1,\n",
        "        result_type=\"expand\",\n",
        "    )\n",
        ")"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 16,
      "metadata": {
        "id": "Iog6ZTX5zib-"
      },
      "outputs": [],
      "source": [
        "%%time\n",
        "training_data_pro = training_data.copy()\n",
        "training_data_pro[[\"gen_answer\", \"context\", \"type_mode\", \"citation\", \"prompt\"]] = (\n",
        "    training_data.apply(\n",
        "        lambda x: get_answer(x, index_db_final, gemini_15_pro, top_n=100),\n",
        "        axis=1,\n",
        "        result_type=\"expand\",\n",
        "    )\n",
        ")"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "lcAS2zj4_KYR"
      },
      "source": [
        "### Answer with Gemini 2.0"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 17,
      "metadata": {
        "id": "wevnMsQ_2LAI"
      },
      "outputs": [],
      "source": [
        "training_data_flash[[\"question\", \"answer\", \"gen_answer\"]].head()"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 18,
      "metadata": {
        "id": "Elpc3_yP2qOv"
      },
      "outputs": [],
      "source": [
        "index = 4\n",
        "\n",
        "print(\"*******The question: *******\\n\")\n",
        "rich_print(training_data_flash[\"question\"][index])\n",
        "\n",
        "print(\"\\n*******The ground-truth answer:*******\\n\")\n",
        "rich_print(training_data_flash[\"answer\"][index])\n",
        "\n",
        "print(\"\\n*******The generated answer: *******\\n\")\n",
        "rich_print(training_data_flash[\"gen_answer\"][index])"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 19,
      "metadata": {
        "id": "BJBf3iqH8Pdy"
      },
      "outputs": [],
      "source": [
        "get_gen_answer_citation(training_data_flash, index, top_n=5, gcs_full_path=False)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "yPbpxpqL_OWn"
      },
      "source": [
        "### Answer with Gemini 2.0"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 20,
      "metadata": {
        "id": "MK5OdaFa2WTw"
      },
      "outputs": [],
      "source": [
        "training_data_pro[[\"question\", \"answer\", \"gen_answer\"]].head()"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 21,
      "metadata": {
        "id": "KzH70mZy3fVz"
      },
      "outputs": [],
      "source": [
        "print(\"*******The question: *******\\n\")\n",
        "rich_print(training_data_pro[\"question\"][index])\n",
        "\n",
        "print(\"\\n*******The ground-truth answer:*******\\n\")\n",
        "rich_print(training_data_pro[\"answer\"][index])\n",
        "\n",
        "print(\"\\n*******The generated answer: *******\\n\")\n",
        "rich_print(training_data_pro[\"gen_answer\"][index])"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 22,
      "metadata": {
        "id": "rcMqmmPM--VL"
      },
      "outputs": [],
      "source": [
        "get_gen_answer_citation(training_data_pro, index, top_n=5, gcs_full_path=False)"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 23,
      "metadata": {
        "id": "t8-OQWUmKIbn"
      },
      "outputs": [],
      "source": [
        "training_data_pro.shape"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "oyH6gsU91TRG"
      },
      "source": [
        "### Save the intermediate Files"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 24,
      "metadata": {
        "id": "i6VI5v_R1TlX"
      },
      "outputs": [],
      "source": [
        "# # [Optional]\n",
        "# import pickle\n",
        "\n",
        "# pickle_file_name =\"training_data_results.pkl\"\n",
        "# data_to_dump = [training_data_flash, training_data_pro]\n",
        "\n",
        "# gcs_location = f\"gs://mlops-for-genai/multimodal-finanace-qa/data/structured/{pickle_file_name}\"\n",
        "\n",
        "# with open(f\"{pickle_file_name}\", \"wb\") as f:\n",
        "#     pickle.dump(data_to_dump, f)\n",
        "\n",
        "\n",
        "# # Upload the pickle file to GCS\n",
        "# !gsutil cp {pickle_file_name} {gcs_location}"
      ]
    }
  ],
  "metadata": {
    "colab": {
      "name": "2.3_mvp_rag.ipynb",
      "toc_visible": true
    },
    "kernelspec": {
      "display_name": "Python 3",
      "name": "python3"
    }
  },
  "nbformat": 4,
  "nbformat_minor": 0
}
